import Optimizer
import random

# this optimizer performs a particle swarm optimization over a given number of batches
# each particle holds a per-bit probability of being 1, which is sampled to produce a discrete state
# swarm size equals the number of states per batch
class ParticleSwarm(Optimizer.Optimizer):
    """Particle swarm optimizer over bit-string states.

    One Particle is created per state in a batch. On every call to
    getNextStates each particle is observed to produce one integer state;
    from the second call onward the particles are first updated with the
    values recorded for the previous batch and then moved.

    Constructor arguments:
        inertia, global_pull, noisy_movement: forwarded unchanged to every
            Particle.
        num_batches: how many batches to iterate over.
        states_per_batch: how many states per batch (one particle each).
        num_bits: number of bits in a state.
        number_top_states, characterizer, verbose: forwarded to the
            Optimizer base class constructor.
        initial_states: optional list of integer states used to seed the
            particles. None, an empty list, or a list whose first entry is
            None all mean that every particle starts at a random position.
    """
    def __init__(self, inertia, global_pull, noisy_movement, num_batches, states_per_batch, num_bits, number_top_states, characterizer, initial_states = None, verbose=False):
        super(ParticleSwarm,self).__init__(number_top_states,characterizer,verbose)
        self.verbose = verbose

        self.num_batches = num_batches

        self.num_bits = num_bits
        self.maximum_state =  2**num_bits-1

        self.batch_counter = 0

        self.inertia = inertia
        self.global_pull = global_pull
        self.noisy_movement = noisy_movement

        self.num_particles = states_per_batch

        # The default (None) and an empty list used to crash on
        # initial_states[0]; treat both like the legacy [None] marker.
        randomize = (
            initial_states is None
            or len(initial_states) == 0
            or initial_states[0] is None
        )

        self.particles = []
        for i in range(self.num_particles):
            # randomize initial positions
            if randomize:
                particle = Particle(self.num_bits,inertia,global_pull,noisy_movement)
            # initial positions are assigned so particles are likely to be at initial positions
            else:
                if i < len(initial_states):
                    discrete_position = initial_states[i]
                else:
                    # more particles than seeds: reuse a randomly chosen seed
                    discrete_position = random.sample(initial_states, 1)[0]
                position = self._seedPosition(discrete_position)
                particle = Particle(self.num_bits,inertia,global_pull,noisy_movement, initial_position = position)
            self.particles.append(particle)

    def _seedPosition(self, state):
        """Return per-bit probabilities that make `state` the likely observation.

        Index i of the result describes bit i of `state` (least significant
        bit first), which is the same ordering DiscretePosition uses when it
        converts coordinates back into an integer. Set bits get probability
        0.9 and clear bits 0.1. Only the low num_bits bits of `state` are used.
        """
        return [0.9 if (state >> bit_idx) & 1 else 0.1
                for bit_idx in range(self.num_bits)]


############### INTERFACE FUNCTIONS ###############

    def isFinished(self):
        """Return True once num_batches batches have been handed out."""
        if self.verbose:
            print("Checking if Finished")
        return self.batch_counter >= self.num_batches

    def getNextStates(self):
        """Return the next batch of states, one observed from each particle.

        On every call after the first, each particle is first told the value
        recorded in self.explored_states for the state it produced last time
        and the current best state self.top_states[0], and is then moved.
        NOTE(review): top_states and explored_states are not set in this
        class; presumably the Optimizer base class maintains them.
        """
        if self.verbose:
            print("Getting Next States")

        # nothing has been observed before the first batch, so skip learning
        if self.batch_counter != 0:
            if self.verbose:
                print("Updating Local and Global Maximums...")

            global_optimum_state = self.top_states[0]
            for i in range(self.num_particles):
                particle = self.particles[i]
                observed_state = self.next_states[i]
                observed_value = self.explored_states[observed_state]
                particle.updateLocalBest(observed_state,observed_value)
                particle.updateGlobalBest(global_optimum_state)

            if self.verbose:
                print("Moving particles...")
            for particle in self.particles:
                particle.moveParticle()
        if self.verbose:
            print("Observing Particle Locations...")

        self.next_states = [particle.observeState() for particle in self.particles]

        self.batch_counter += 1
        return self.next_states

############### Child Class Helper Functions ###############

class Particle(object):
    """One swarm member: a probabilistic position, a velocity and its memory.

    num_bits: number of bits in a state
    inertia: weight kept from the previous velocity; the remainder
        (1 - inertia) weights the pull toward the best states
    global_pull: share of that pull directed at the global best; the
        remainder (1 - global_pull) is directed at this particle's own best
    noisy_movement: when true, each pull is randomly scaled before use
    initial_position / initial_velocity: optional starting values handed to
        QuantumPosition / QuantumVelocity
    """
    def __init__(self, num_bits, inertia, global_pull, noisy_movement, initial_position = None,initial_velocity = None):
        self.num_bits = num_bits
        self.noisy_movement = noisy_movement

        # complementary weight pairs, each pair sums to 1
        self.inertia, self.attraction = inertia, 1.0 - inertia
        self.global_pull, self.local_pull = global_pull, 1.0 - global_pull

        self.position = QuantumPosition(self.num_bits, bit_probabilities=initial_position)
        self.velocity = QuantumVelocity(self.num_bits, velocities=initial_velocity)

        # nothing observed yet, so any first value will become the local best
        self.global_best = None
        self.local_best = None
        self.local_best_value = -float("inf")

    def observeState(self):
        """Sample the position once and return the resulting integer state."""
        return self.position.observe().getState()

    def _pullToward(self, target, weight):
        """Return the weighted velocity pointing from the position to target."""
        offset = target - self.position
        if self.noisy_movement:
            offset.addNoise()
        return offset * weight

    def moveParticle(self):
        """Advance the position by the current velocity, then update velocity."""
        # local pull is evaluated before global pull, both from the old position
        toward_local = self._pullToward(self.local_best, self.local_pull)
        toward_global = self._pullToward(self.global_best, self.global_pull)

        self.position = self.position + self.velocity
        steering = (toward_local + toward_global) * self.attraction
        self.velocity = self.velocity * self.inertia + steering

    def updateLocalBest(self,local_state,local_value):
        """Remember local_state if its value is at least the best seen so far."""
        if not (local_value >= self.local_best_value):
            return
        self.local_best_value = local_value
        self.local_best = DiscretePosition(self.num_bits, state=local_state)

    def updateGlobalBest(self,global_state):
        """Record global_state as the swarm-wide best position."""
        self.global_best = DiscretePosition(self.num_bits, state=global_state)

# location of particle in discrete state space: one 0/1 coordinate per bit, least significant bit first
class DiscretePosition(object):
    """A concrete bit-string state stored as a list of 0/1 coordinates.

    coordinates[i] is bit i of the state, least significant bit first.

    num_bits: number of bits in the state
    coordinates: optional list of bits to use as-is (takes precedence)
    state: optional integer whose low num_bits bits become the coordinates
    If neither is given, every bit is drawn at random.
    """
    def __init__(self, num_bits, coordinates = None, state = None):
        self.num_bits = num_bits
        if coordinates is not None:
            self.coordinates = coordinates
        elif state is not None:
            # is each bit a 1 or 0
            self.coordinates = [(state >> bit_idx) & 1
                                for bit_idx in range(self.num_bits)]
        else:
            self.coordinates = [random.randint(0, 1) for _ in range(num_bits)]

    def __sub__(self, other):
        """Return a QuantumVelocity pointing from `other` toward this position.

        Supports DiscretePosition and QuantumPosition operands; any other
        type raises Exception. Operand types are checked with isinstance so
        the result does not depend on the dotted path the module was
        imported under.
        """
        if isinstance(other, DiscretePosition):
            velocities = []
            for bit_idx in range(self.num_bits):
                e1 = self.coordinates[bit_idx]
                e2 = other.coordinates[bit_idx]
                # if they are the same, don't change anything
                if e1 == e2:
                    velocities.append(0)
                # if they are different, indicate directional difference to equal e1
                else:
                    velocities.append(2*(e1-0.5))
            # return a quantum velocity indicating the desired direction
            return QuantumVelocity(self.num_bits,velocities=velocities)
        elif isinstance(other, QuantumPosition):
            # reuse (quantum - discrete) and flip its sign
            return other.__sub__(self)*-1
        else:
            exception_name = "Subtraction not defined for Type DiscretePosition and Type " + str(type(other))
            raise Exception(exception_name)

    def getState(self):
        """Return the integer encoded by the coordinates (bit i has weight 2**i)."""
        state = 0
        for bit_idx, bit_value in enumerate(self.coordinates):
            state += bit_value * 2**(bit_idx)
        return state

class QuantumPosition(object):
    """A probabilistic position: one probability of being 1 for every bit.

    num_bits: number of bits in the state
    bit_probabilities: optional list of per-bit probabilities, used as-is;
        when omitted each probability is drawn from random.random()
    """
    def __init__(self, num_bits, bit_probabilities = None):
        self.num_bits = num_bits
        if bit_probabilities is None:
            self.bit_probabilities = [random.random() for _ in range(self.num_bits)]
        else:
            self.bit_probabilities = bit_probabilities

    def observe(self):
        """Sample every bit once and return the result as a DiscretePosition."""
        bit_states = []
        for bit_probability in self.bit_probabilities:
            # a bit is 1 when the uniform draw falls at or below its probability
            if random.random() <= bit_probability:
                bit_states.append(1)
            else:
                bit_states.append(0)
        return DiscretePosition(self.num_bits, coordinates = bit_states)

    # applies bayes rule using quantum velocity to change probability
    def __add__(self, other):
        """Return the QuantumPosition reached by applying a QuantumVelocity.

        Each velocity v is turned into the likelihood 0.5*(1+v) of seeing
        that velocity given the bit is 1, and the bit probability is replaced
        by the resulting posterior. Raises Exception for any other operand.
        """
        if not isinstance(other, QuantumVelocity):
            exception_name = "Addition not defined for Type QuantumPosition and Type " + str(type(other))
            raise Exception(exception_name)
        new_bit_probabilities = []
        for bit_index in range(self.num_bits):
            p_on = self.bit_probabilities[bit_index]
            p_off = 1.0 - p_on
            p_velocity_given_on = 0.5*(1+other.velocities[bit_index])
            p_velocity_given_off = 1.0 - p_velocity_given_on

            p_velocity = p_velocity_given_on*p_on + p_velocity_given_off*p_off
            p_on_given_velocity = p_velocity_given_on*p_on/p_velocity
            new_bit_probabilities.append(p_on_given_velocity)
        return QuantumPosition(self.num_bits, bit_probabilities= new_bit_probabilities)

    def __sub__(self, other):
        """Return the per-bit difference (self - other) as a QuantumVelocity.

        `other` may be a DiscretePosition (its coordinates are subtracted) or
        another QuantumPosition (its bit probabilities are subtracted). Any
        other type raises Exception.
        """
        if isinstance(other, DiscretePosition):
            other_locations = other.coordinates
        elif isinstance(other, QuantumPosition):
            other_locations = other.bit_probabilities
        else:
            # the original message named DiscretePosition as the left operand
            exception_name = "Subtraction not defined for Type QuantumPosition and Type " + str(type(other))
            raise Exception(exception_name)

        velocities = [self.bit_probabilities[bit_idx] - other_locations[bit_idx]
                      for bit_idx in range(self.num_bits)]
        # return a quantum velocity indicating the desired direction
        return QuantumVelocity(self.num_bits,velocities=velocities)


# a quantum velocity holds one number between -1 and 1 for every bit
# each number represents how much weight to use when adjusting a quantum position
# using Bayes' rule
# -1 implies that the quantum position P(0) will become 1
#  1 implies that the quantum position P(1) will become 1
# thus we set a max speed to ensure this never happens
class QuantumVelocity(object):
    """Per-bit velocities, each kept within [-max_speed, max_speed].

    num_bits: number of bits (one velocity per bit)
    velocities: optional list of velocities; values outside the allowed
        range are clamped. When omitted, each velocity is drawn uniformly
        from [-1, 1) and then clamped.
    """
    def __init__(self, num_bits, velocities = None):
        # stay strictly inside (-1, 1) so a bit probability is never forced to 0 or 1
        self.max_speed = 1 - 10**(-6)
        self.num_bits = num_bits
        if velocities is None:
            # random number between -1 and 1
            velocities = [2*(random.random()-0.5) for _ in range(num_bits)]
        # bound velocity to max_speed
        self.velocities = self.bound_velocities(velocities)

    def __add__(self, other):
        """Return the element-wise sum with another QuantumVelocity.

        Raises Exception when `other` is not a QuantumVelocity or has a
        different number of bits. The result is clamped by the constructor.
        """
        if not isinstance(other, QuantumVelocity):
            exception_name = "Addition not defined for Type QuantumVelocity and Type " + str(type(other))
            raise Exception(exception_name)
        if self.num_bits != other.num_bits:
            raise Exception("Cannot add velocities with different dimensions")
        summed = [self.velocities[bit_idx] + other.velocities[bit_idx]
                  for bit_idx in range(self.num_bits)]
        return QuantumVelocity(self.num_bits,velocities = summed)

    def __mul__(self, other):
        """Return every velocity scaled by the int or float `other`.

        Raises Exception for any other operand type (including bool, which
        the original type-name check also rejected). The result is clamped
        by the constructor.
        """
        # isinstance works on both Python 2 and 3, unlike comparing "<type 'int'>"
        is_number = isinstance(other, (int, float)) and not isinstance(other, bool)
        if not is_number:
            exception_name = "Multiplication not defined for Type QuantumVelocity and Type " + str(type(other))
            raise Exception(exception_name)
        scaled = [other * velocity for velocity in self.velocities]
        return QuantumVelocity(self.num_bits,velocities = scaled)

    def addNoise(self):
        """Scale each velocity in place by an independent factor in [0, 1)."""
        for i in range(len(self.velocities)):
            self.velocities[i] = self.velocities[i] * random.random()

    def bound_velocities(self, velocities):
        """Return a new list with every value clamped to +/- max_speed."""
        cap = self.max_speed
        return [max(-1*cap, min(cap, v)) for v in velocities]
